from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging

import numpy
import torch
from facenet_pytorch import InceptionResnetV1
from torchvision.transforms import transforms

from .face_model import FaceModel
from .singleton import Singleton

# Module-scoped logger, named after this module.
LOG = logging.getLogger(name=__name__)


@Singleton
class FaceNetTorch(FaceModel):
    """Face-embedding model backed by facenet_pytorch's ``InceptionResnetV1``.

    The class is wrapped by the project's ``Singleton`` decorator. The network
    is built in ``__init__`` with the ``'vggface2'`` pretrained weights, moved
    to CUDA when it is available, and switched to eval mode.
    """

    def __init__(self):
        super().__init__()
        self.model = None
        # Converts an image array to a float CHW tensor, then normalizes each
        # channel with the mean/std below.
        # NOTE(review): these are the ImageNet statistics. facenet_pytorch's own
        # pipeline standardizes pixels with ``fixed_image_standardization``
        # ((x - 127.5) / 128 on 0-255 values) before InceptionResnetV1.
        # Changing this would change every embedding (and invalidate any that
        # are stored), so it is left as-is. Confirm which preprocessing is
        # intended.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.create_model()

    def init_app(self, app):
        """No-op hook; this model needs no per-app setup.

        Presumably part of a Flask-style extension interface. Confirm against
        ``FaceModel``.
        """
        pass

    def create_model(self):
        """Build the pretrained InceptionResnetV1 and prepare it for inference.

        The model is moved to the GPU when CUDA is available and put in eval
        mode. Per the original note, TORCH_HOME should be set in the OS
        environment; presumably it controls where the pretrained weights are
        cached.
        """
        # set TORCH_HOME in your os environment
        self.model = InceptionResnetV1(pretrained='vggface2')
        if torch.cuda.is_available():
            self.model.cuda()
        self.model.eval()

    def get_embeddings(self, x: numpy.ndarray):
        """Compute the embedding for a single image.

        :param x: one image as an H x W x C array; ``ToTensor`` converts it to
            C x H x W. Presumably an aligned face crop. TODO confirm the
            expected size and dtype.
        :return: numpy array holding the first (and only) row of the model
            output for ``x``.
        """
        # Send the input to wherever the model's weights actually live. This
        # stays consistent with create_model() instead of re-querying CUDA,
        # which could disagree with the model's real placement.
        device = next(self.model.parameters()).device
        batch = self.transform(x).unsqueeze(0).to(device)  # add batch dim
        # Inference only: don't build an autograd graph for the forward pass.
        with torch.no_grad():
            embeddings = self.model(batch)
        return embeddings.cpu().numpy()[0]

    def get_prediction(self, x: numpy.ndarray, **kwargs) -> numpy.ndarray:
        """Not implemented for this embedding model; returns None."""
        return None

    def get_predictions(self, x: numpy.ndarray, **kwargs) -> (float, numpy.ndarray):
        """Not implemented for this embedding model; returns None."""
        return None
